package keymutex

import (
	"sync"
	"sync/atomic"
)

// mutexPool recycles *sync.Mutex values so that per-key locks do not have to
// be allocated every time a key is locked for the first time.
var mutexPool = sync.Pool{
	// New supplies a fresh, unlocked mutex when the pool has none to hand out.
	New: func() any { return new(sync.Mutex) },
}

// LockObj pairs the mutex that guards a single key with a usage counter.
type LockObj struct {
	// Lock is the mutex that serializes access for one key.
	Lock *sync.Mutex
	// Num is incremented each time the key's lock is requested and
	// decremented each time it is released; it therefore counts the callers
	// that currently hold or are waiting for Lock.
	Num  int64
}

// KeyMutex is a mutex keyed by string: callers using different keys do not
// block each other, giving finer lock granularity and better concurrency than
// a single shared mutex.
// The per-key mutexes are reused through a sync.Pool to reduce memory
// allocations.
type KeyMutex struct {
	// globalMutex guards every access to the locks map.
	globalMutex sync.Mutex
	// locks holds one entry per key that is currently locked or being waited on.
	locks       map[string]*LockObj
}

// NewKeyMutex returns an empty KeyMutex that is ready for use.
func NewKeyMutex() *KeyMutex {
	km := new(KeyMutex)
	// The map must be initialized up front: writing to a nil map panics.
	km.locks = make(map[string]*LockObj)
	return km
}

// getLock returns the mutex associated with key and registers the caller as a
// user of it by bumping the entry's counter. If the key has no entry yet, a
// mutex is taken from mutexPool and a new entry with a count of one is stored.
// The whole operation runs under globalMutex.
func (l *KeyMutex) getLock(key string) *sync.Mutex {
	l.globalMutex.Lock()
	defer l.globalMutex.Unlock()

	entry, found := l.locks[key]
	if !found {
		// First user of this key: take a mutex from the pool and register it.
		entry = &LockObj{
			Lock: mutexPool.Get().(*sync.Mutex),
			Num:  1,
		}
		l.locks[key] = entry
		return entry.Lock
	}

	// The key already has a mutex: reuse it and record one more user.
	atomic.AddInt64(&entry.Num, 1)
	return entry.Lock
}

// Lock acquires the lock for key, blocking until it becomes available.
func (l *KeyMutex) Lock(key string) {
	// The per-key mutex is looked up (and ref-counted) under the global mutex,
	// but it is locked only after getLock has returned, so waiting on one key
	// does not hold the global mutex.
	keyLock := l.getLock(key)
	keyLock.Lock()
}

// Unlock releases the lock held for key.
//
// Under the global mutex it unlocks the key's mutex and decrements the key's
// usage counter. Once the counter reaches zero, no caller holds or waits for
// the key any more, so the mutex is handed back to mutexPool and the key's
// entry is removed from the map.
//
// Unlock panics with a descriptive message if key has no entry, i.e. it was
// never locked or has already been fully released. Like sync.Mutex, unlocking
// a key that is not locked is a programming error.
func (l *KeyMutex) Unlock(key string) {
	l.globalMutex.Lock()
	defer l.globalMutex.Unlock()

	// Look the entry up once; a missing key would otherwise surface as an
	// opaque nil-pointer dereference.
	lockObj, ok := l.locks[key]
	if !ok {
		panic("keymutex: unlock of unlocked key: " + key)
	}

	lockObj.Lock.Unlock()

	// Clean up: drop the entry and recycle its mutex when nobody uses it.
	if atomic.AddInt64(&lockObj.Num, -1) <= 0 {
		mutexPool.Put(lockObj.Lock)
		delete(l.locks, key)
	}
}
